#!/usr/bin/env python3

import argparse
import datetime
import functools
import os
import pathlib
import re
import statistics
import subprocess
import sys
import tempfile

import git
import pandas
import plotly
import plotly.express
import tqdm

@functools.total_ordering
class Commit:
    """
    This class represents a commit inside a given Git repository.

    A commit's identity is the full SHA its revision resolves to (see `fullrev`), so objects
    created from different spellings of the same revision (e.g. a short and a full SHA) compare
    equal and hash identically. Ordering is based on ancestry (see `__lt__`); ancestry is not a
    total order for commits on divergent branches, so comparing such commits is not meaningful.

    Most properties shell out to Git and are cached after the first access; `prefetch()` can be
    used to pay that cost at a chosen time.
    """

    def __init__(self, git_repo, sha):
        # Repository path handed to `git -C`.
        self._git_repo = git_repo
        # Revision as provided by the caller (possibly abbreviated); `fullrev` resolves it.
        self._sha = sha

    def __eq__(self, other):
        """
        Return whether two commits refer to the same commit.

        This doesn't take into account the content of the Git tree at those commits, only the
        'identity' of the commits themselves. Comparing against an object that is not a `Commit`
        returns `NotImplemented`, letting Python apply its default handling instead of failing
        with an AttributeError.
        """
        if not isinstance(other, Commit):
            return NotImplemented
        return self.fullrev == other.fullrev

    def __lt__(self, other):
        """
        Return whether this commit is a strict ancestor of another commit in the Git repository.

        `git merge-base --is-ancestor` considers every commit an ancestor of itself, so equal
        commits are handled up front: a commit is never less than itself. This keeps the
        operators derived by `functools.total_ordering` (`<=`, `>`, `>=`) consistent with `==`.

        Raises RuntimeError if Git fails for a reason other than "not an ancestor".
        """
        if not isinstance(other, Commit):
            return NotImplemented
        if self == other:
            return False
        # Is self._sha an ancestor of other._sha? Exit status 0 means yes, 1 means no.
        res = subprocess.run(['git', '-C', self._git_repo, 'merge-base', '--is-ancestor', self._sha, other._sha])
        if res.returncode not in (0, 1):
            raise RuntimeError(f'Error when trying to obtain the commit order for {self._sha} and {other._sha}')
        return res.returncode == 0

    def __hash__(self):
        """
        Return a hash of the full revision of this commit, consistent with `__eq__`.
        """
        return hash(self.fullrev)

    @functools.cache
    def show(self, include_diff=False):
        """
        Return the output of `git show` for this commit.

        By default the diff is omitted (`--no-patch`); pass `include_diff=True` to include it.
        Results are memoized with `functools.cache`.
        """
        cmd = ['git', '-C', self._git_repo, 'show', self._sha]
        if not include_diff:
            cmd.append('--no-patch')
        return subprocess.check_output(cmd, text=True)

    @functools.cached_property
    def shortrev(self):
        """
        Return the shortened version of the given SHA, as printed by `git rev-parse --short`.
        """
        return subprocess.check_output(['git', '-C', self._git_repo, 'rev-parse', '--short', self._sha], text=True).strip()

    @functools.cached_property
    def fullrev(self):
        """
        Return the full SHA associated to this commit, as printed by `git rev-parse`.
        """
        return subprocess.check_output(['git', '-C', self._git_repo, 'rev-parse', self._sha], text=True).strip()

    @functools.cached_property
    def commit_date(self):
        """
        Return the committer date of the commit as a naive `datetime.datetime` in local time.
        """
        repo = git.Repo(self._git_repo)
        return datetime.datetime.fromtimestamp(repo.commit(self._sha).committed_date)

    def prefetch(self):
        """
        Prefetch cached properties associated to this commit object.

        This makes it possible to control when time is spent recovering that information from Git for
        e.g. better reporting to the user.
        """
        self.commit_date
        self.fullrev
        self.shortrev
        self.show()

    def __str__(self):
        """
        Return the revision exactly as it was provided (it may be abbreviated).
        """
        return self._sha

def truncate_lines(string, n, marker=None):
    """
    Truncate the given string to at most `n` lines.

    Optionally, when `marker` is provided and truncation actually removed lines, the last kept
    line is replaced by `marker` to identify that truncation has happened. Lines are split with
    `str.splitlines()` and re-joined with newline characters.

    Raises ValueError if `n` is negative.
    """
    if n < 0:
        raise ValueError(f'cannot truncate to a negative number of lines ({n})')
    lines = string.splitlines()
    truncated = lines[:n]
    # With n == 0 no line is left to carry the marker; an empty result is the only output
    # that honors the "at most n lines" contract, so the marker is dropped in that case.
    if marker is not None and truncated and len(lines) > len(truncated):
        truncated[-1] = marker
    assert len(truncated) <= n, "broken post-condition"
    return '\n'.join(truncated)

def create_plot(data, metric, trendline=None, subtitle=None):
    """
    Build a scatter plot of `metric` for every benchmark across the commits present in `data`.

    `data` must provide the 'revlist_order', 'benchmark' and 'commit' columns as well as a
    column named after `metric`. Each benchmark gets its own color and symbol, the x axis is
    the 'revlist_order' value, and each point's hover label is the (truncated) `git show`
    summary of its commit. `trendline` and `subtitle` are forwarded to plotly express.
    """
    ordered = data.sort_values(by=['revlist_order', 'benchmark'])

    # pandas.unique keeps first-appearance order, and the rows are sorted by 'revlist_order',
    # so the commits come out in revision-list order (oldest first).
    commits = pandas.unique(ordered['commit'])

    summaries = {}
    for commit in commits:
        text = truncate_lines(commit.show(), 30, marker='...')
        summaries[commit] = text.replace('\n', '<br>')
    hover_names = [summaries[commit] for commit in ordered['commit']]

    oldest, newest = commits[0], commits[-1]
    return plotly.express.scatter(
        ordered,
        title=f"{oldest.shortrev} to {newest.shortrev}",
        subtitle=subtitle,
        x='revlist_order',
        y=metric,
        symbol='benchmark',
        color='benchmark',
        hover_name=hover_names,
        trendline=trendline,
    )

def directory_path(string):
    """
    Argparse `type=` converter: return `string` as a `pathlib.Path` if it names an existing
    directory.

    Raises `argparse.ArgumentTypeError` otherwise. argparse reports that exception as a normal
    usage error; other exception types (e.g. OSError subclasses) would escape from
    `parse_args` as an uncaught traceback.
    """
    if not os.path.isdir(string):
        raise argparse.ArgumentTypeError(f'{string} is not a directory')
    return pathlib.Path(string)

def parse_lnt(lines, aggregate=statistics.median):
    """
    Parse lines in LNT format and return a list of dictionaries of the form:

        [
            {
                'benchmark': <benchmark1>,
                <metric1>: float,
                <metric2>: float,
                ...
            },
            {
                'benchmark': <benchmark2>,
                <metric1>: float,
                <metric2>: float,
                ...
            },
            ...
        ]

    Each non-blank line must look like `<benchmark>.<metric> <value>`. The value is the last
    whitespace-separated token and the metric is the text after the last '.', so benchmark
    names may themselves contain dots or spaces. Blank lines are ignored. Benchmarks and
    metrics appear in the order in which they are first seen.

    If a metric has multiple values associated to it, they are aggregated into a single
    value using the provided aggregation function (which receives a list of floats).

    Raises ValueError on a malformed line or a non-numeric value.
    """
    # benchmark -> metric -> list of raw samples, in first-seen order.
    samples = {}
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue

        parts = line.rsplit(maxsplit=1)
        if len(parts) != 2 or '.' not in parts[0]:
            raise ValueError(f'Malformed LNT line: {raw_line!r}')
        (identifier, value) = parts
        (benchmark, metric) = identifier.rsplit('.', maxsplit=1)
        samples.setdefault(benchmark, {}).setdefault(metric, []).append(float(value))

    return [
        {'benchmark': benchmark} | {metric: aggregate(values) for (metric, values) in metrics.items()}
        for (benchmark, metrics) in samples.items()
    ]

def sorted_revlist(git_repo, commits):
    """
    Return the list of commits sorted from oldest to newest in the provided Git repository.
    Items earlier in the list are older than items later in the list.

    Ordering is delegated to `git rev-list --no-walk`, which lists the given revisions (without
    walking their history) in reverse chronological order of commit time; the result is
    reversed here. Each entry is the full SHA printed by `git rev-list`.

    An empty `commits` yields an empty list without invoking Git, because `git rev-list`
    requires at least one revision.
    """
    commits = list(commits)
    if not commits:
        return []
    revlist_cmd = ['git', '-C', git_repo, 'rev-list', '--no-walk'] + commits
    revlist = subprocess.check_output(revlist_cmd, text=True).strip().splitlines()
    return list(reversed(revlist))

def main(argv):
    """
    Command-line entry point.

    Parse `argv` (the arguments without the program name), load every `<commit>.lnt` file from
    the given directory, and write an HTML file embedding an interactive plot of the requested
    metric, with commits ordered by `sorted_revlist`.

    Raises RuntimeError when there is nothing to plot: no LNT files, LNT files without any
    data, a `--filter` that excludes every benchmark, or a `--metric` absent from the data.
    """
    parser = argparse.ArgumentParser(
        prog='visualize-historical',
        description='Visualize historical data in LNT format. This program generates an HTML file that embeds an '
                    'interactive plot with the provided data. The HTML file can then be opened in a browser to '
                    'visualize the data as a chart.',
        epilog='This script depends on the modules listed in `libcxx/utils/requirements.txt`.')
    parser.add_argument('directory', type=directory_path,
        help='Path to a valid directory containing benchmark data in LNT format, each file being named <commit>.lnt. '
             'This is also the format generated by the `benchmark-historical` utility.')
    parser.add_argument('--output', '-o', type=pathlib.Path, required=False,
        help='Optional path where to output the resulting HTML file. If it already exists, it is overwritten. '
             'Defaults to a temporary file which is opened automatically once generated, but not removed after '
             'creation.')
    parser.add_argument('--metric', type=str, default='execution_time',
        help='The metric to compare. LNT data may contain multiple metrics (e.g. code size, execution time, etc) -- '
             'this option allows selecting which metric is being visualized. The default is "execution_time".')
    parser.add_argument('--filter', type=str, required=False,
        help='An optional regular expression used to filter the benchmarks included in the chart. '
             'Only benchmarks whose names match the regular expression will be included. '
             'Since the chart is interactive, it generally makes most sense to include all the benchmarks '
             'and to then filter them in the browser, but in some cases producing a chart with a reduced '
             'number of data series is useful.')
    parser.add_argument('--subtitle', type=str, required=False,
        help='Optional subtitle for the chart. This can be used to help identify the contents of the chart.')
    parser.add_argument('--git-repo', type=directory_path, default=pathlib.Path(os.getcwd()),
        help='Path to the git repository to use for ordering commits in time. '
             'By default, the current working directory is used.')
    parser.add_argument('--open', action='store_true',
        help='Whether to automatically open the generated HTML file when finished. If no output file is provided, '
             'the resulting HTML file is opened automatically by default.')
    parser.add_argument('--trendline', type=str, required=False, default=None, choices=('ols', 'lowess', 'expanding'),
        help='Optional trendline to add on each series in the chart. See the documentation in '
             'https://plotly.com/python-api-reference/generated/plotly.express.trendline_functions.html '
             'for details on each option.')
    args = parser.parse_args(argv)

    # Extract benchmark data from the directory. Each file is named after the commit it measures.
    files = list(args.directory.glob('*.lnt'))
    if not files:
        raise RuntimeError(f'No .lnt files found in {args.directory} -- nothing to plot')
    data = {}
    for file in tqdm.tqdm(files, desc='Parsing LNT files'):
        rows = parse_lnt(file.read_text().splitlines())
        commit = Commit(args.git_repo, file.stem)
        data[commit] = rows

    # Obtain commit information which is then cached throughout the program. Do this
    # eagerly so we can provide a progress bar.
    for commit in tqdm.tqdm(data.keys(), desc='Prefetching Git information'):
        commit.prefetch()

    # Create a dataframe from the raw data and add some columns to it:
    # - 'commit' represents the Commit object associated to the results in that row
    # - `revlist_order` represents the order of the commit within the Git repository.
    # - `date` represents the commit date
    revlist = sorted_revlist(args.git_repo, [c.fullrev for c in data.keys()])
    # Compute each commit's position once; calling revlist.index() for every row is quadratic.
    position = {sha: index for (index, sha) in enumerate(revlist)}
    data = pandas.DataFrame([row | {'commit': c} for (c, rows) in data.items() for row in rows])
    if data.empty:
        raise RuntimeError(f'The .lnt files in {args.directory} contain no data -- nothing to plot')
    data['revlist_order'] = [position[c.fullrev] for c in data['commit']]
    data['date'] = [c.commit_date for c in data['commit']]

    # Filter the benchmarks if needed.
    if args.filter is not None:
        pattern = re.compile(args.filter)
        data = data[data['benchmark'].map(lambda name: pattern.search(name) is not None)]
        if len(data) == 0:
            raise RuntimeError(f'Filter "{args.filter}" resulted in empty data set -- nothing to plot')

    # Fail with an actionable message rather than letting plotly reject an unknown column.
    if args.metric not in data.columns:
        metrics = sorted(set(data.columns) - {'benchmark', 'commit', 'revlist_order', 'date'})
        raise RuntimeError(f'Metric "{args.metric}" is not present in the data -- available metrics: {", ".join(metrics)}')

    # Plot the data for all the required benchmarks.
    figure = create_plot(data, args.metric, trendline=args.trendline, subtitle=args.subtitle)
    do_open = args.output is None or args.open
    if args.output is not None:
        output = args.output
    else:
        # mkstemp creates a uniquely-named file and leaves it in place. Taking `.name` from a
        # throw-away NamedTemporaryFile would delete that file as soon as the object is
        # collected, leaving only a path that something else could claim before it is written.
        (fd, output) = tempfile.mkstemp(suffix='.html')
        os.close(fd)
    plotly.io.write_html(figure, file=output, auto_open=do_open)

if __name__ == '__main__':
    # Forward the command-line arguments, minus the program name, to main().
    main(sys.argv[1:])
